# encoding=utf-8

import os

import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as layers


def reduce(feature_map):
    """
    :param feature_map: [batch_size, height, width, depth]
    :return:
    """
    static_shape = feature_map.shape.as_list()
    dynamic_shape = tf.shape(feature_map)
    shape = [d if d else dynamic_shape[i] for i, d in enumerate(static_shape)]
    padding = tf.stack([[0, 0],
                        [0, shape[1] % 2],
                        [0, shape[2] % 2],
                        [0, 0]])
    feature_map = tf.pad(feature_map, padding)
    new_shape_ = tf.stack([shape[0], (1 + shape[1]) // 2, 2, (1 + shape[2]) // 2, 2, shape[3]])
    feature_map = tf.reshape(feature_map, new_shape_)
    feature_map = tf.transpose(feature_map, [0, 1, 3, 2, 4, 5])
    new_shape = tf.stack([shape[0], (1 + shape[1]) // 2, (1 + shape[2]) // 2, 2 * 2 * shape[3]])
    feature_map = tf.reshape(feature_map, new_shape)
    return feature_map


def expand(feature_map):
    """
    :param feature_map: [batch_size, height, width, depth]
    :return:
    """
    static_shape = feature_map.shape.as_list()
    dynamic_shape = tf.shape(feature_map)
    shape = [d if d else dynamic_shape[i] for i, d in enumerate(static_shape)]
    new_shape_ = tf.stack([shape[0], shape[1], shape[2], 2, 2, shape[3] // 4])
    feature_map = tf.reshape(feature_map, new_shape_)
    feature_map = tf.transpose(feature_map, [0, 1, 3, 2, 4, 5])
    new_shape = tf.stack([shape[0], shape[1] * 2, shape[2] * 2, shape[3] // 4])
    feature_map = tf.reshape(feature_map, new_shape)
    return feature_map
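

# A minimal shape sanity check for the reduce/expand pair (not called by the
# training code). It assumes even spatial dimensions, so reduce() adds no
# padding and expand() inverts it exactly.
def _check_reduce_expand_shapes():
    x = tf.zeros([2, 28, 28, 16])
    y = expand(reduce(x))  # 28x28x16 -> 14x14x64 -> 28x28x16
    with tf.Session() as sess:
        assert sess.run(y).shape == (2, 28, 28, 16)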


def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis.

    :param x: feature map, [batch_size, height, width, depth]
    :param y: conditioning map, [batch_size, 1, 1, y_dim]
    :return: [batch_size, height, width, depth + y_dim]
    """
    x_shape = tf.shape(x)  # dynamic shape, so an unknown batch size is fine
    y_dim = y.get_shape().as_list()[3]
    return tf.concat([x, y * tf.ones([x_shape[0], x_shape[1], x_shape[2], y_dim])], 3)
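
# Example with hypothetical shapes: for x of shape [32, 14, 14, 128] and
# y of shape [32, 1, 1, 10], conv_cond_concat broadcasts y over the 14x14
# grid and returns a [32, 14, 14, 138] tensor.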


def build_generator(z, y=None, reuse=False):
    with tf.variable_scope("generator") as scope:
        if reuse:
            scope.reuse_variables()

        if y is not None:
            z = tf.concat((z, y), axis=-1)
            yy = tf.reshape(y, [-1, 1, 1, 10])  # label as a 1x1 map for conv_cond_concat

        fc0 = layers.fully_connected(z, 4 * 4 * 256, activation_fn=tf.nn.leaky_relu)
        conv1 = tf.reshape(fc0, [-1, 4, 4, 256])  # 4x4
        conv1 = layers.batch_norm(conv1, updates_collections="generator", reuse=reuse,
                                  scope="batch_norm1")
        if y is not None:
            conv1 = conv_cond_concat(conv1, yy)

        # conv2 = layers.conv2d(conv1, 4 * 64, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        # conv2 = expand(conv2)  # 4x4
        # conv2 = layers.batch_norm(conv2, updates_collections="generator", reuse=reuse,
        #                           scope="batch_norm2")
        # if y is not None:
        #     conv2 = conv_cond_concat(conv2, yy)

        conv3 = layers.conv2d(conv1, 4 * 128, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv3 = expand(conv3)  # 8x8
        conv3 = layers.batch_norm(conv3, updates_collections="generator", reuse=reuse,
                                  scope="batch_norm3")
        if y is not None:
            conv3 = conv_cond_concat(conv3, yy)

        conv4 = layers.conv2d(conv3, 4 * 32, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv4 = expand(conv4)  # 16x16
        conv4 = layers.batch_norm(conv4, updates_collections="generator", reuse=reuse,
                                  scope="batch_norm4")
        if y is not None:
            conv4 = conv_cond_concat(conv4, yy)

        conv5 = layers.conv2d(conv4, 4 * 1, kernel_size=3, activation_fn=tf.nn.sigmoid)
        conv5 = expand(conv5)  # 32x32
        output = tf.slice(conv5, [0, 2, 2, 0], [-1, 28, 28, -1])  # centre-crop to 28x28

    return output
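

# Illustrative sketch of sampling class-conditional digits once the generator
# variables exist (e.g. after build_gan() has run and a checkpoint written by
# train() has been restored). The helper below is an assumption about a
# typical sampling workflow, not part of the training pipeline.
def _sample_digits(digit, batch_size=16):
    z = tf.random_normal([batch_size, 100])
    y = tf.one_hot(tf.fill([batch_size], digit), 10, dtype=tf.float32)
    return build_generator(z, y, reuse=True)  # reuse the trained weights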


def build_discriminator(x, y=None, reuse=False):
    with tf.variable_scope("discriminator") as scope:
        if reuse:
            scope.reuse_variables()

        if y is not None:
            yy = tf.reshape(y, [-1, 1, 1, 10])  # label as a 1x1 map for conv_cond_concat

        x = 2.0 * x - 1.0  # rescale inputs from [0, 1] to [-1, 1]
        if y is not None:
            x = conv_cond_concat(x, yy)

        conv1 = layers.conv2d(x, 128, kernel_size=5, activation_fn=tf.nn.leaky_relu)
        conv1 = reduce(conv1)  # 14x14
        conv1 = layers.batch_norm(conv1, updates_collections="discriminator", reuse=reuse,
                                  scope="batch_norm1")
        if y is not None:
            conv1 = conv_cond_concat(conv1, yy)

        conv2 = layers.conv2d(conv1, 256, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv2 = reduce(conv2)  # 7x7
        conv2 = layers.batch_norm(conv2, updates_collections="discriminator", reuse=reuse,
                                  scope="batch_norm2")
        if y is not None:
            conv2 = conv_cond_concat(conv2, yy)

        conv3 = layers.conv2d(conv2, 256, kernel_size=3, activation_fn=tf.nn.leaky_relu)
        conv3 = reduce(conv3)  # 4x4
        conv3 = layers.batch_norm(conv3, updates_collections="discriminator", reuse=reuse,
                                  scope="batch_norm3")
        if y is not None:
            conv3 = conv_cond_concat(conv3, yy)

        flat = layers.flatten(conv3)
        fc4 = layers.fully_connected(flat, 1024, activation_fn=tf.nn.leaky_relu)
        if y is not None:
            fc4 = tf.concat((fc4, y), axis=-1)

        fc5 = layers.fully_connected(fc4, 1, activation_fn=tf.nn.sigmoid)

    return fc5


def build_gan():
    end_points = {}
    batch_size = 32

    global_step = tf.train.get_or_create_global_step()
    update_step_op = tf.assign_add(global_step, 1)

    real_x = tf.placeholder(tf.float32, [batch_size, 28, 28, 1])
    real_y = tf.placeholder(tf.float32, [batch_size, 10])

    random_x = tf.random_normal([batch_size, 100])
    random_y = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
    random_y = tf.one_hot(random_y, 10, dtype=tf.float32)
    fake_x = build_generator(random_x, random_y, reuse=False)

    tf.summary.image("fake_x", fake_x, max_outputs=32)
    tf.summary.image("real_x", real_x, max_outputs=32)

    real_prd = build_discriminator(real_x, real_y, reuse=False)
    fake_prd = build_discriminator(fake_x, random_y, reuse=True)

    # Mismatched-label pass: pair real images with random labels so the
    # discriminator must check image/label consistency, not just realism.
    random_y_ = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32)
    random_y_ = tf.one_hot(random_y_, 10, dtype=tf.float32)
    real_prd_rdm = build_discriminator(real_x, random_y_, reuse=True)
    # 1 where the random label happens to match the true label, else 0
    label_rdm = tf.reduce_sum(random_y_ * real_y, axis=-1, keep_dims=True)

    # Noisy label smoothing: real targets drawn near 1, fake targets near 0.
    noise_label = tf.random_uniform(tf.shape(real_prd), minval=0.98, maxval=1.0)
    noise_label_ = tf.random_uniform(tf.shape(real_prd), minval=0.0, maxval=0.02)
    eps = 1e-8  # keeps tf.log away from log(0)
    dis_loss_p = -(noise_label * tf.log(real_prd + eps) + (1 - noise_label) * tf.log(1 - real_prd + eps))
    dis_loss_n = -(noise_label_ * tf.log(fake_prd + eps) + (1 - noise_label_) * tf.log(1 - fake_prd + eps))
    dis_loss_rd = -(label_rdm * tf.log(real_prd_rdm + eps) + (1 - label_rdm) * tf.log(1 - real_prd_rdm + eps))
    dis_loss = dis_loss_p + dis_loss_n + dis_loss_rd
    gen_loss = -tf.log(fake_prd + eps)  # non-saturating generator loss

    dis_loss_p = tf.reduce_mean(dis_loss_p)
    dis_loss_n = tf.reduce_mean(dis_loss_n)
    dis_loss = tf.reduce_mean(dis_loss)
    gen_loss = tf.reduce_mean(gen_loss)

    t_vars = tf.trainable_variables()
    d_vars = [var for var in t_vars if 'discriminator' in var.name]
    g_vars = [var for var in t_vars if 'generator' in var.name]

    dis_ups = tf.get_collection("discriminator")
    with tf.control_dependencies(dis_ups):
        # train_dis_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=dis_loss, var_list=d_vars)
        train_dis_p_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=dis_loss_p,
                                                                                          var_list=d_vars)
        train_dis_n_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=dis_loss_n,
                                                                                          var_list=d_vars)
        train_dis_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5).minimize(loss=dis_loss, var_list=d_vars)
    gen_ups = tf.get_collection("generator")
    with tf.control_dependencies(gen_ups):
        # train_gen_op = tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(loss=gen_loss, var_list=g_vars)
        train_gen_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5).minimize(loss=gen_loss, var_list=g_vars)

    end_points["real_x"] = real_x
    end_points["real_y"] = real_y
    end_points["fake_x"] = fake_x
    end_points["dis_loss"] = dis_loss
    end_points["gen_loss"] = gen_loss
    end_points["train_dis_op"] = train_dis_op
    end_points["train_gen_op"] = train_gen_op
    end_points["global_step"] = global_step
    end_points["update_step_op"] = update_step_op
    end_points["summaries"] = tf.summary.merge_all()
    end_points["dis_loss_p"] = dis_loss_p
    end_points["dis_loss_n"] = dis_loss_n
    end_points["train_dis_p_op"] = train_dis_p_op
    end_points["train_dis_n_op"] = train_dis_n_op

    return end_points


def train():
    xx, yy = load_mnist("data/mnist")
    state = {"st": 0}
    batch_size = 32

    def get_next_batch():
        st = state['st']
        ed = (st + batch_size) % len(xx)
        state['st'] = ed
        if ed > st:
            return xx[st:ed, ...], yy[st:ed, ...]
        else:
            return np.concatenate([xx[st:, ...], xx[:ed, ...]], axis=0), \
                   np.concatenate([yy[st:, ...], yy[:ed, ...]], axis=0)
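
    # Example of the wraparound: with len(xx) == 70000 and st == 69990, the
    # next call returns a full batch of 32 built from xx[69990:] (10 samples)
    # and xx[:22] (22 samples).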

    end_points = build_gan()

    real_x = end_points["real_x"]
    real_y = end_points["real_y"]
    fake_x = end_points["fake_x"]
    dis_loss = end_points["dis_loss"]
    gen_loss = end_points["gen_loss"]
    train_dis_op = end_points["train_dis_op"]
    train_gen_op = end_points["train_gen_op"]
    global_step = end_points["global_step"]
    update_step_op = end_points["update_step_op"]
    summaries = end_points["summaries"]
    dis_loss_p = end_points["dis_loss_p"]
    dis_loss_n = end_points["dis_loss_n"]
    train_dis_p_op = end_points["train_dis_p_op"]
    train_dis_n_op = end_points["train_dis_n_op"]

    sv = tf.train.Supervisor(logdir="logdir")
    with sv.managed_session() as sess:
        step = sess.run(global_step)
        while step < 200000:
            bx, by = get_next_batch()
            if step % 100 == 0:
                _, dls, fx, sumr = sess.run([train_dis_op, dis_loss, fake_x, summaries],
                                            feed_dict={real_x: bx, real_y: by})
                # _, dls_p = sess.run([train_dis_p_op, dis_loss_p],
                #                     feed_dict={real_x: get_next_batch()})
                # _, dls_n, fx, sumr = sess.run([train_dis_n_op, dis_loss_n, fake_x, summaries])
                # dls = (dls_p + dls_n) / 2
                sv.summary_writer.add_summary(sumr, step)
            else:
                _, dls, fx = sess.run([train_dis_op, dis_loss, fake_x],
                                      feed_dict={real_x: bx, real_y: by})
                # _, dls_p = sess.run([train_dis_p_op, dis_loss_p],
                #                     feed_dict={real_x: get_next_batch()})
                # _, dls_n, fx = sess.run([train_dis_n_op, dis_loss_n, fake_x])
                # dls = (dls_p + dls_n) / 2
            # Three generator updates per discriminator update.
            _, gls0, fx = sess.run([train_gen_op, gen_loss, fake_x])
            _, gls1, fx = sess.run([train_gen_op, gen_loss, fake_x])
            _, gls2, fx = sess.run([train_gen_op, gen_loss, fake_x])
            _, step = sess.run([update_step_op, global_step])
            print("%08d:\t%0.8f\t%0.8f" % (step, dls, (gls0 + gls1 + gls2) / 3))


def load_mnist(data_dir):
    """Read the raw IDX-format MNIST files and return (images, one_hot_labels).

    Images are scaled to [0, 1] with shape [N, 28, 28, 1]; labels are one-hot
    float32 vectors of length 10. Train and test splits are concatenated and
    shuffled with a fixed seed.
    """
    # The IDX files are binary, so they must be opened in 'rb' mode.
    fd = open(os.path.join(data_dir, 'train-images-idx3-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trX = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float32)

    fd = open(os.path.join(data_dir, 'train-labels-idx1-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    trY = loaded[8:].reshape((60000,))

    fd = open(os.path.join(data_dir, 't10k-images-idx3-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float32)

    fd = open(os.path.join(data_dir, 't10k-labels-idx1-ubyte'), 'rb')
    loaded = np.fromfile(file=fd, dtype=np.uint8)
    teY = loaded[8:].reshape((10000,))

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int64)

    # Shuffle images and labels identically by reseeding before each shuffle.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    y_vec = np.zeros((len(y), 10), dtype=np.float32)
    y_vec[np.arange(len(y)), y] = 1.0

    return X / 255., y_vec
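

# Quick sanity check for load_mnist (illustrative, not called during training;
# assumes the four raw IDX files are present under data/mnist):
#
#   X, y_vec = load_mnist("data/mnist")
#   assert X.shape == (70000, 28, 28, 1) and 0.0 <= X.min() <= X.max() <= 1.0
#   assert y_vec.shape == (70000, 10) and (y_vec.sum(axis=1) == 1.0).all()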


if __name__ == '__main__':
    train()
